Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better treatment for numerical errors #330

Merged
merged 4 commits into from
Oct 9, 2024

Conversation

guilhermebodin
Copy link
Member

@guilhermebodin guilhermebodin commented Oct 9, 2024

replace #327
close #313

@raphaelsaavedra
Copy link
Member

raphaelsaavedra commented Oct 9, 2024

I think I actually prefer the approach in #327 (but throwing a proper error instead of just printing) because it doesn't split logdet into two functions, which is probably faster and doesn't lose the logdet advantages:

Logarithm of matrix determinant. Equivalent to log(det(M)), but may provide increased accuracy and avoids overflow/underflow.

I'm not a big fan of using try/catch but in this case it seems fine?

@raphaelsaavedra
Copy link
Member

Thanks for taking care of this btw :)

@guilhermebodin
Copy link
Member Author

I did the benchmarks

using BenchmarkTools

function scalar_with_try_catch(n::Int, F::Float64, v::Float64)
    llk = 0.0
    HALF_LOG_2_PI = 0.5 * log(2*pi)
    for i in 1:n
        try 
            llk -= (
                HALF_LOG_2_PI + 0.5 * (log(F) + v^2 / F)
            )
        catch
            error("Numerical error, F is negative: $F")
        end
    end
    return llk
end

function scalar_with_check(n::Int, F::Float64, v::Float64)
    llk = 0.0
    HALF_LOG_2_PI = 0.5 * log(2*pi)
    for i in 1:n
        if F < 0
            error("Numerical error, F is negative: $F")
        end
        llk -= (
            HALF_LOG_2_PI + 0.5 * (log(F) + v^2 / F)
        )
    end
    return llk
end

n = 1000
F = 1.0
v = 1.0

julia> @btime scalar_with_try_catch($n, $F, $v)
  23.400 μs (0 allocations: 0 bytes)
-1418.9385332046932

julia> @btime scalar_with_check($n, $F, $v)
  5.714 μs (0 allocations: 0 bytes)
-1418.9385332046932

function vetorial_with_try_catch(n::Int, F::Matrix{Float64}, v::Vector{Float64})
    llk = 0.0
    HALF_LOG_2_PI = 0.5 * log(2*pi)
    for i in 1:n
        try 
            llk -=
                HALF_LOG_2_PI + 0.5 * (logdet(F) +
                v' * inv(F) * v)
        catch
            error("Numerical error, F is negative: $(F[i])")
        end
    end
    return llk
end

function vetorial_with_check(n::Int, F::Matrix{Float64}, v::Vector{Float64})
    llk = 0.0
    HALF_LOG_2_PI = 0.5 * log(2*pi)
    for i in 1:n
        detF = det(F)
        if detF < 0
            error("Numerical error, F is negative: $(F[i])")
        end
        llk -=
            HALF_LOG_2_PI + 0.5 * (log(detF) +
            v' * inv(F) * v)
    end
    return llk
end

function vetorial_with_check_and_logdet(n::Int, F::Matrix{Float64}, v::Vector{Float64})
    llk = 0.0
    HALF_LOG_2_PI = 0.5 * log(2*pi)
    for i in 1:n
        detF = det(F)
        if detF < 0
            error("Numerical error, F is negative: $(F[i])")
        end
        llk -=
            HALF_LOG_2_PI + 0.5 * (logdet(F) +
            v' * inv(F) * v)
    end
    return llk
end

n = 1000
F = Matrix{Float64}(I, 3, 3)
v = ones(3)

julia> GC.gc()

julia> GC.gc()

julia> @btime vetorial_with_try_catch($n, $F, $v)
  3.747 ms (4000 allocations: 218.75 KiB)
-2418.9385332047

julia> GC.gc()

julia> GC.gc()

julia> @btime vetorial_with_check($n, $F, $v) # might have been a little unlucky in this one
  8.012 ms (6000 allocations: 296.88 KiB)
-2418.9385332047

julia> GC.gc()

julia> GC.gc()

julia> @btime vetorial_with_check_and_logdet($n, $F, $v)
  6.133 ms (6000 allocations: 296.88 KiB)
-2418.9385332047

In the scalar case it makes more difference, not using the try catch, in the vetorial case my approach actually hurts performance.

I found this post https://discourse.julialang.org/t/try-catch-statement-really-slow-even-if-catch-is-never-executed/91744/6 revealing that each try means a few nanoseconds.

I am going to change it bakc to try catch statements

@raphaelsaavedra
Copy link
Member

raphaelsaavedra commented Oct 9, 2024

What about moving the try block outside the for loop? It doesn't matter in which step it fails, it only matters that it fails:

function scalar_with_try_catch(n::Int, F::Float64, v::Float64)
    llk = 0.0
    HALF_LOG_2_PI = 0.5 * log(2*pi)
    try
        for i in 1:n
            llk -= (
                HALF_LOG_2_PI + 0.5 * (log(F) + v^2 / F)
            )
        end
    catch
        error("Numerical error, F is negative")
    end
    return llk
end

@guilhermebodin
Copy link
Member Author

This was a test to mimic the package behavior, realistically we would have to put the try-catch blocks in the filter recursions.

function filter_recursions!(
    kalman_state::MultivariateKalmanState{Fl},
    sys::LinearMultivariateTimeInvariant,
    steadystate_tol::Fl,
    skip_llk_instants::Int,
) where Fl
    RQR = sys.R * sys.Q * sys.R'
    @inbounds for t in 1:size(sys.y, 1)
        update_kalman_state!(
            kalman_state,
            sys.y[t, :],
            sys.Z,
            sys.T,
            sys.H,
            RQR,
            sys.d,
            sys.c,
            skip_llk_instants,
            steadystate_tol,
            t,
        )
    end
    return kalman_state.llk
end

It is also a viable option.

Comment on lines 65 to 67
catch
@error("Numerical error in the log-likelihood calculation. F = $(kalman_state.F), v = $(kalman_state.v). F can only be positive.")
rethrow()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering if this is any different than just

catch e
     @error("Numerical error... :, $e)

But looks good to me!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this one does not stop the program, only logs to screen

julia> try; log(-10); catch; @error("oi"); rethrow() end
┌ Error: oi
└ @ Main REPL[102]:1
ERROR: DomainError with -10.0:
log was called with a negative real argument but will only return a complex result if called with a complex argument. Try log(Complex(x)).
Stacktrace:
 [1] throw_complex_domainerror(f::Symbol, x::Float64)
   @ Base.Math .\math.jl:33
 [2] _log
   @ .\special\log.jl:295 [inlined]
 [3] log(x::Float64)
   @ Base.Math .\special\log.jl:261
 [4] log(x::Int64)
   @ Base.Math .\math.jl:1531
 [5] top-level scope
   @ REPL[102]:1

julia> try; log(-10); catch ex; @error("oi", ex) end
┌ Error: oi
│   ex =
│    DomainError with -10.0:
│    log was called with a negative real argument but will only return a complex result if called with a complex argument. Try log(Complex(x)).
└ @ Main REPL[103]:1

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, it's been a while since I programmed in julia... forgot @error just logs! Thanks

@guilhermebodin
Copy link
Member Author

@raphaelsaavedra I moved the try-catch blocks to filter recursions. I think this is the best way to go. If there is any error, it will report the entire Kalman state and tell where the error happened.

@guilhermebodin guilhermebodin merged commit b0fe56c into master Oct 9, 2024
4 checks passed
@guilhermebodin guilhermebodin deleted the gb/treat-numerical-errors branch October 9, 2024 19:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Getting DomainError when fitting some models
2 participants